"""add iteration_step table

Revision ID: c91513775753
Revises: 9bb7bb8b71c3
Create Date: 2022-10-14 14:50:55.795359

"""
from typing import Any, List, Optional

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
# ``revision`` is this migration's id; ``down_revision`` is the revision it
# builds on (matches "Revises" in the module docstring).
revision = "c91513775753"
down_revision = "9bb7bb8b71c3"
# No branch label and no cross-branch dependency for this revision.
branch_labels = None
depends_on = None


def get_record(conn: Any, table: str, id_: Optional[int]) -> Any:
    """
    Fetch a single row by its ``id`` column from ``table``.

    :param conn: an SQLAlchemy connection (e.g. the one from ``op.get_bind()``).
    :param table: table name. It is interpolated into the SQL text, so it must be a
        trusted, hard-coded identifier -- never user input.
    :param id_: value of the ``id`` column to look up; ``None`` short-circuits
        without touching the database.
    :return: the first matching row, or ``None`` if ``id_`` is ``None`` or no row matches.
    """
    if id_ is None:
        return None
    # Bind the id as a parameter instead of formatting it into the SQL, and wrap the
    # statement in ``sa.text``: SQLAlchemy 2.0 no longer accepts plain strings in
    # ``Connection.execute`` (1.4 only emits a deprecation warning).
    statement = sa.text(f"SELECT * FROM {table} WHERE id = :id")
    return conn.execute(statement, {"id": id_}).first()


def parse_steps_info_from_iteration(conn: Any, iteration: Any) -> List:
    """
    Build the ``iteration_step`` rows describing one existing iteration record.

    Each step takes its ``task_id`` from the record referenced by the iteration:

    - prepare_mining   -> dataset row ``iteration.mining_input_dataset_id``
    - mining           -> dataset row ``iteration.mining_output_dataset_id``
    - label            -> dataset row ``iteration.label_output_dataset_id``
    - prepare_training -> dataset row ``iteration.training_input_dataset_id``
    - training         -> model row   ``iteration.training_output_model_id``

    A step whose referenced record is missing gets ``task_id`` None. Every step
    is marked finished and not deleted. Each step copies the iteration's id and
    create/update timestamps.
    """
    # (step name, task_type code, table holding the record, iteration attribute with its id)
    step_specs = (
        ("prepare_mining", 11, "dataset", "mining_input_dataset_id"),
        ("mining", 2, "dataset", "mining_output_dataset_id"),
        ("label", 3, "dataset", "label_output_dataset_id"),
        ("prepare_training", 11, "dataset", "training_input_dataset_id"),
        ("training", 1, "model", "training_output_model_id"),
    )

    # Resolve all referenced records first, in spec order.
    records = [get_record(conn, table, getattr(iteration, id_attr)) for _, _, table, id_attr in step_specs]

    rows = []
    for (step_name, task_type, _, _), record in zip(step_specs, records):
        row = {
            "name": step_name,
            "task_type": task_type,
            "task_id": record.task_id if record else None,
        }
        row["iteration_id"] = iteration.id
        row["create_datetime"] = iteration.create_datetime
        row["update_datetime"] = iteration.update_datetime
        row["is_deleted"] = False
        row["is_finished"] = True
        rows.append(row)
    return rows


def upgrade() -> None:
    """
    Create the ``iteration_step`` table and its indexes, then backfill it with
    the step rows derived from every existing ``iteration`` row.

    The backfill is best-effort. If it fails, the error is printed and the
    migration continues with an empty ``iteration_step`` table.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    iteration_step_table = op.create_table(
        "iteration_step",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("name", sa.String(length=100), nullable=False),
        sa.Column("iteration_id", sa.Integer(), nullable=False),
        sa.Column("task_type", sa.Integer(), nullable=False),
        sa.Column("task_id", sa.Integer(), nullable=True),
        sa.Column("serialized_presetting", sa.Text(length=20000), nullable=True),
        sa.Column("is_finished", sa.Boolean(), nullable=False),
        sa.Column("is_deleted", sa.Boolean(), nullable=False),
        sa.Column("create_datetime", sa.DateTime(), nullable=False),
        sa.Column("update_datetime", sa.DateTime(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    with op.batch_alter_table("iteration_step", schema=None) as batch_op:
        batch_op.create_index(batch_op.f("ix_iteration_step_id"), ["id"], unique=False)
        batch_op.create_index(batch_op.f("ix_iteration_step_iteration_id"), ["iteration_id"], unique=False)
        batch_op.create_index(batch_op.f("ix_iteration_step_name"), ["name"], unique=False)
        batch_op.create_index(batch_op.f("ix_iteration_step_task_id"), ["task_id"], unique=False)
        batch_op.create_index(batch_op.f("ix_iteration_step_task_type"), ["task_type"], unique=False)
    # ### end Alembic commands ###

    # Copy data from the iteration table into the newly created iteration_step table.
    conn = op.get_bind()
    try:
        # ``sa.text`` is required here: SQLAlchemy 2.0 rejects plain SQL strings in
        # ``Connection.execute``. Without it this query would raise on every run, and
        # the handler below would always skip the backfill.
        iterations = conn.execute(sa.text("SELECT * FROM iteration")).fetchall()
        steps_data = [step for iteration in iterations for step in parse_steps_info_from_iteration(conn, iteration)]
        op.bulk_insert(iteration_step_table, steps_data)
    except Exception as e:
        # NOTE(review): some backends (e.g. PostgreSQL) abort the whole transaction
        # after a failed statement. On those, swallowing the error here may still not
        # let the migration complete -- confirm on the target database.
        print("Could not migrate iteration data to iteration_step, skip: %s" % e)


def downgrade() -> None:
    """Undo ``upgrade``: drop each ``iteration_step`` index, then the table itself."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Indexes are dropped in the reverse order of their creation in upgrade().
    index_names = (
        "ix_iteration_step_task_type",
        "ix_iteration_step_task_id",
        "ix_iteration_step_name",
        "ix_iteration_step_iteration_id",
        "ix_iteration_step_id",
    )
    with op.batch_alter_table("iteration_step", schema=None) as batch_op:
        for index_name in index_names:
            batch_op.drop_index(batch_op.f(index_name))

    op.drop_table("iteration_step")
    # ### end Alembic commands ###
